% size(mat_K) = N x K
% size(mat_M) = N x M
% the output matrix mat_KM is N x (K*M); row jj is kron(mat_M(jj,:), mat_K(jj,:)):
% mat_KM(jj,:) = [mat_M(jj,1)*mat_K(jj,:)...
%                 mat_M(jj,2)*mat_K(jj,:)...
%                 ...
%                 mat_M(jj,M)*mat_K(jj,:)];

function mat_KM = expand_agg_states(mat_K, mat_M)
% EXPAND_AGG_STATES  Row-wise Kronecker product of two matrices.
%   mat_KM = EXPAND_AGG_STATES(mat_K, mat_M) expects mat_K of size N x K
%   and mat_M of size N x M, and returns mat_KM of size N x (K*M).
%   The output consists of M consecutive blocks of K columns: block m is
%   mat_K with each row jj multiplied by the scalar mat_M(jj,m). So for
%   every row jj,
%       mat_KM(jj,:) = kron(mat_M(jj,:), mat_K(jj,:)).

% Row count and block width come from mat_K; the number of blocks from mat_M.
nRows   = size(mat_K, 1);
blkW    = size(mat_K, 2);
nBlocks = size(mat_M, 2);

mat_KM = zeros(nRows, blkW*nBlocks);

for b = 1:nBlocks
    % Column indices covered by the b-th block of width blkW.
    blkCols = (b - 1)*blkW + (1:blkW);
    % Copy column b of mat_M across the block, then scale mat_K elementwise.
    scale = repmat(mat_M(:, b), 1, blkW);
    mat_KM(:, blkCols) = mat_K .* scale;
end